from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.mixins import (
    CreateModelMixin,
    DestroyModelMixin,
    ListModelMixin,
    RetrieveModelMixin,
    UpdateModelMixin,
)
from rest_framework.viewsets import GenericViewSet, ViewSet

from utils.common_pagination import PageNumberPagination
from utils.common_response import APIResponse

from .authentication import LoginAuthentication
from .models import UserInfo
from .serializer import UserLoginSerializer, UserSerializer, UserCreateUpdateSerializer


class UserLoginView(ViewSet):
    '''
    Login / registration view set.

    NOTE: an earlier version issued tokens directly via
    ``rest_framework_simplejwt.tokens.RefreshToken.for_user``. Its author noted
    that, without Django's built-in auth user model, simplejwt errors out unless
    line 3 of ``rest_framework_simplejwt.serializers`` is commented out. The
    current ``login`` delegates token creation to ``UserLoginSerializer`` instead.
    '''

    @action(methods=['POST'], detail=False)
    def login(self, request):
        '''
        Validate the submitted credentials and return the login payload.

        Validation errors propagate as DRF exceptions (``raise_exception=True``).
        On success, ``token``, ``username`` and ``avatar`` are read from the
        serializer's context, which ``UserLoginSerializer`` is expected to
        populate during validation (confirm against the serializer).
        '''
        login_serializer = UserLoginSerializer(
            data=request.data,
            context={'request': request},
        )
        login_serializer.is_valid(raise_exception=True)

        result = login_serializer.context
        return APIResponse(
            token=result.get('token'),
            username=result.get('username'),
            avatar=result.get('avatar'),
        )


class UserView(GenericViewSet, ListModelMixin, RetrieveModelMixin, DestroyModelMixin, CreateModelMixin,
               UpdateModelMixin):
    '''
    User management view set.

    Provides list / retrieve / create / update / destroy over users whose
    ``is_delete`` flag is False, plus the extra actions ``batch_delete``,
    ``roles`` and ``repass``. All responses except ``destroy`` (inherited
    unchanged from ``DestroyModelMixin``) are wrapped in ``APIResponse``.
    '''
    authentication_classes = [LoginAuthentication]
    serializer_class = UserSerializer
    # Every action, including get_object() lookups on detail routes, only sees
    # users not flagged as deleted.
    queryset = UserInfo.objects.filter(is_delete=False).order_by('id')
    pagination_class = PageNumberPagination

    # Password assigned by the ``repass`` action. Kept as a class attribute so it
    # can be overridden without editing the action.
    # NOTE(review): a fixed, well-known reset password is weak; consider a
    # generated one.
    default_password = '123456'

    # Actions that render users with the read serializer; every other action
    # (create/update/partial_update/extra actions/None) uses the write serializer.
    _read_actions = frozenset({'list', 'retrieve'})

    def get_serializer_class(self):
        '''Return ``UserSerializer`` for read actions, ``UserCreateUpdateSerializer`` otherwise.'''
        if self.action in self._read_actions:
            return UserSerializer
        return UserCreateUpdateSerializer

    @staticmethod
    def _get_list_param(request, key):
        '''
        Return ``request.data[key]`` as a list, or ``[]`` when the key is absent.

        Raises:
            ValidationError: (HTTP 400) if the request body is not a mapping or
                the value under ``key`` is not a list/tuple.
        '''
        data = request.data
        if not isinstance(data, dict):
            raise ValidationError({'detail': 'Request body must be an object.'})
        # Form-encoded bodies arrive as a QueryDict; ``get`` would return only the
        # last value of a repeated key (a string), so collect all values instead.
        if hasattr(data, 'getlist'):
            return data.getlist(key)
        value = data.get(key, [])
        if not isinstance(value, (list, tuple)):
            raise ValidationError({key: 'Expected a list.'})
        return list(value)

    def list(self, request, *args, **kwargs):
        '''List users (paginated) and wrap the result in ``APIResponse``.'''
        # Forward router kwargs (e.g. ``format`` on format-suffix routes) instead
        # of rejecting them.
        res = super().list(request, *args, **kwargs)
        return APIResponse(data=res.data)

    def retrieve(self, request, *args, **kwargs):
        '''Return a single user wrapped in ``APIResponse``.'''
        res = super().retrieve(request, *args, **kwargs)
        return APIResponse(data=res.data)

    @action(methods=['POST'], detail=False)
    def batch_delete(self, request, *args, **kwargs):
        '''
        Delete every visible user whose id appears in ``request.data['ids']``.

        A missing ``ids`` key deletes nothing. A non-list value is rejected with
        HTTP 400 rather than being handed to the ORM.
        '''
        ids = self._get_list_param(request, 'ids')
        # NOTE(review): this is a hard delete, yet the queryset filters on
        # ``is_delete``, which suggests soft deletion is used elsewhere.
        # Confirm which is intended.
        self.get_queryset().filter(id__in=ids).delete()
        return APIResponse()

    @action(methods=['POST'], detail=True)
    def roles(self, request, *args, **kwargs):
        '''
        Replace the user's roles with ``request.data['roles']``.

        The values are passed to ``user.roles.set`` (presumably role primary
        keys; confirm with the client). A missing ``roles`` key clears all roles.
        '''
        user = self.get_object()
        # RelatedManager.set() writes the many-to-many rows itself; no
        # user.save() is needed (it would only rewrite the unchanged user row).
        user.roles.set(self._get_list_param(request, 'roles'))
        return APIResponse()

    def create(self, request, *args, **kwargs):
        '''Create a user and wrap the serialized result in ``APIResponse``.'''
        res = super().create(request, *args, **kwargs)
        return APIResponse(data=res.data)

    def update(self, request, *args, **kwargs):
        '''Update a user (also serves PATCH via ``partial_update``) and wrap the result.'''
        res = super().update(request, *args, **kwargs)
        return APIResponse(data=res.data)

    # Reset password
    @action(methods=['POST'], detail=True)
    def repass(self, request, *args, **kwargs):
        '''
        Reset the user's password to ``default_password``.

        The stored value comes from ``user.make_password(...)``, a method
        defined on ``UserInfo`` (not visible here; presumably it hashes).
        '''
        user = self.get_object()
        user.password = user.make_password(self.default_password)
        user.save()
        return APIResponse()

